# -*- coding:utf-8 -*-

from __future__ import division
import tensorflow as tf
from element.dnn_module import fc_layer
from element.cnn_module import conv2d_def, fc_layer_from_conv2d, max_pool_def
from config.glob.global_pool import global_pool

"""
LeNet5
"""


def net(self):
    """Build the LeNet5 graph on ``self.xs`` and store the result in ``self.y_pred``.

    Layers, in construction order (each under its own ``tf.name_scope``):

        conv2d_1  (6 channels, 5x5)  -> pool_1 (max-pool, 2x2 window, stride 2)
        conv2d_2  (16 channels, 5x5) -> pool_2 (max-pool, 2x2 window, stride 2)
        fc_1      (120 units, relu)
        fc_2      (84 units, relu)
        dropout   (only when ``global_pool.config.net.dropout.use`` is truthy)
        layer7_output (10 units, softmax)

    Reads:
        self.xs: input tensor handed to the first convolution
            (presumably an NHWC image batch -- TODO confirm against caller).
        global_pool.config.net.dropout: ``use`` flag and ``rate`` value.

    Writes:
        self.y_pred: output of the final softmax layer.

    Returns:
        None.
    """

    def _conv_5x5(inputs, channels):
        # Both convolutions share the same kernel size and padding.
        return conv2d_def(inputs, out_channel=channels, kernel_size=(5, 5), padding='SAME')

    def _pool_2x2(inputs):
        # Fresh ksize/strides lists on every call, as in the two original call sites.
        return max_pool_def(
            inputs,
            ksize=[1, 2, 2, 1],
            strides=[1, 2, 2, 1],
            padding='SAME',
        )

    # --- convolutional feature extractor ---
    with tf.name_scope('conv2d_1'):
        conv_a = _conv_5x5(self.xs, 6)
    with tf.name_scope('pool_1'):
        pooled_a = _pool_2x2(conv_a)
    with tf.name_scope('conv2d_2'):
        conv_b = _conv_5x5(pooled_a, 16)
    with tf.name_scope('pool_2'):
        pooled_b = _pool_2x2(conv_b)

    # --- fully connected head ---
    with tf.name_scope('fc_1'):
        dense_a = fc_layer_from_conv2d(pooled_b, 120, activation_func=tf.nn.relu)
    with tf.name_scope('fc_2'):
        features = fc_layer(dense_a, 84, activation_func=tf.nn.relu)

    # NOTE(review): dropout is gated only by the config flag; nothing in this
    # function distinguishes training from inference -- confirm that the caller
    # disables it (or sets the rate to 0) for evaluation.
    dropout_cfg = global_pool.config.net.dropout
    if dropout_cfg.use:
        drop_rate = dropout_cfg.rate  # fraction of units to drop
        with tf.name_scope('dropout'):
            features = tf.nn.dropout(features, rate=drop_rate)

    # NOTE(review): the output is already softmax-normalised; verify that the
    # loss used elsewhere expects probabilities rather than raw logits.
    with tf.name_scope('layer7_output'):
        self.y_pred = fc_layer(features, 10, activation_func=tf.nn.softmax)
